package LiuYun

import spinal.core._
import spinal.lib._

/** Global elaboration-time switch for the RAM wrappers in this file.
  *
  * When `use_vivado` is true, the `*BlockRam` components instantiate the
  * black-box RAMs; otherwise they build a SpinalHDL `Mem`. The flag is read
  * when each component is constructed, so it must be set before elaboration.
  */
object RamConfig {
    var use_vivado: Boolean = false
}

/** Post-reset sweep counter that visits every `w`-bit index once.
  *
  * `cnt` is a (w+1)-bit register reset to 0. While its MSB is clear, `valid`
  * is high and `cnt` increments every cycle. `idx` (the low `w` bits of `cnt`)
  * therefore takes each value 0 .. 2^w - 1 in turn. When `cnt` reaches 2^w,
  * the MSB sets, `valid` drops and the counter stops.
  *
  * NOTE(review): presumably used to initialise/clear RAM entries after reset
  * (confirm with callers). Also, this holds a Reg plus logic, so it behaves
  * more like an `Area` than a data `Bundle`.
  *
  * @param w index width in bits
  */
case class RamInit(w:Int) extends Bundle{
    // One extra bit beyond the index width acts as the "sweep finished" flag.
    val cnt   = Reg(UInt((w+1) bits)) init U(0)
    val idx   = cnt(0, w bits)
    val valid = !cnt.msb
    // Stop counting once the MSB is set, so the counter saturates at 2^w.
    when(valid){
        cnt := cnt + 1
    }
}

/** Black box for a single-port block RAM with a whole-word write enable.
  *
  * The port names (clka/ena/wea/addra/dina/douta) are the names expected on
  * the external module. They presumably match a Vivado Block Memory Generator
  * IP (confirm against the IP configuration). Read latency is determined by
  * that IP, not by this declaration.
  *
  * @param w data width in bits
  * @param a address width in bits
  */
class SpBlockRam(val w:Int,val a:Int) extends BlackBox{
    val clka  = in (Bool())
    val ena   = in (Bool())
    val wea   = in (Bool())
    val addra = in (UInt(a bits))
    val dina  = in (Bits(w bits))
    val douta = out(Bits(w bits))
    // Drive clka from the clock of the clock domain this black box is instantiated in.
    mapCurrentClockDomain(clka)
}
/** Black box for a single-port block RAM with a per-lane ("byte") write enable.
  *
  * NOTE(review): unlike `SpBlockRam`, here `w` is the number of write-enable
  * lanes, not the data width. Each lane is `bitPerByte` bits wide, so the data
  * ports are `n = w * bitPerByte` bits. Port names presumably match a Vivado
  * Block Memory Generator IP with byte-write enabled (confirm).
  *
  * @param w          number of byte lanes (width of `wea`)
  * @param a          address width in bits
  * @param bitPerByte width of one lane in bits
  */
class SpByteBRam(val w:Int,val a:Int,val bitPerByte:Int) extends BlackBox{
    // Total data width: w lanes of bitPerByte bits each.
    val n = w * bitPerByte
    val clka  = in (Bool())
    val ena   = in (Bool())
    val wea   = in (Bits(w bits))
    val addra = in (UInt(a bits))
    val dina  = in (Bits(n bits))
    val douta = out(Bits(n bits))
    // Drive clka from the clock of the clock domain this black box is instantiated in.
    mapCurrentClockDomain(clka)
}

// Concrete RAM black boxes. The class names are emitted as the module names,
// so they presumably must match the generated IP names in the Vivado project.

// 32-bit data, 8-bit address (256 entries).
class inst_bram extends SpBlockRam(32,8)
// 21-bit data, 8-bit address (256 entries).
class tagv_bram extends SpBlockRam(21,8)

// 4 lanes of 9 bits (36-bit word), 9-bit address (512 entries).
class data_bram extends SpByteBRam( 4,8+1,9)

/** Single-port RAM control interface.
  *
  * The master drives enable (`en`), write enable (`we`), address (`a`) and
  * write data (`wd`). The slave returns read data (`rd`). A write happens
  * when both `en` and `we` are high; a read when `en` is high and `we` is low.
  *
  * @param w_ram data width in bits
  * @param w_idx address width in bits (the RAM holds `1 << w_idx` words)
  */
class SpRamCtrl(val w_ram:Int,val w_idx:Int) extends Bundle with IMasterSlave{
    val en = Bool()
    val we = Bool()
    val a  = UInt(w_idx bits)
    val wd = UInt(w_ram bits)
    val rd = UInt(w_ram bits)

    override def asMaster():Unit = {
        out(en,we,a,wd)
        in (rd)
    }

    /** Wires this interface to a block-RAM black box.
      *
      * Read latency is whatever the black box implements.
      */
    def connect(ram:SpBlockRam):Unit = {
        ram.ena   := en
        ram.wea   := we
        ram.addra := a
        ram.dina  := wd.asBits
        rd := ram.douta.asUInt
    }

    /** Master helper: enable and write in the same cycle exactly when `cond` holds. */
    def setWriteWhen(cond:Bool):Unit = {
        en := cond
        we := cond
    }

    /** Master helper: read when `cond` holds; never writes. */
    def setReadWhen(cond:Bool):Unit = {
        en := cond
        we := False
    }

    /** Master helper: unconditionally write this cycle. */
    def setWrite():Unit = {
        en := True
        we := True
    }

    /** Master helper: unconditionally read this cycle. */
    def setRead():Unit = {
        en := True
        we := False
    }

    /** Master helper: neither read nor write this cycle. */
    def setIdle():Unit = {
        en := False
        we := False
    }

    /** Creates a SpinalHDL memory matching this interface's width and depth. */
    def getMem():Mem[UInt] = Mem(UInt(w_ram bits), 1 << w_idx)

    /** Wires this interface to `mem` through a single synchronous read/write port.
      *
      * Being `readWriteSync`, `rd` is registered, i.e. valid the cycle after the
      * access.
      */
    def connect(mem:Mem[UInt]):Unit = {
        rd := mem.readWriteSync(a,wd,en,we)
    }
}
/** Single-port RAM control interface with a per-byte write mask and a dirty bit.
  *
  * The data word is split into `nbyte` 8-bit lanes. Each lane is stored
  * together with one extra bit, giving 9 stored bits per lane:
  *  - on a write, every lane enabled by `wm` receives its byte of `wd`, and
  *    its extra bit receives `wb`;
  *  - on a read, `rb` is the OR of the extra bits of all lanes of the
  *    addressed word.
  *
  * Lane `i` always maps to bits `[i*8, i*8+8)` of `wd`/`rd` and to bit `i` of `wm`.
  *
  * NOTE(review): the inherited `connect(SpBlockRam)` / `connect(Mem[UInt])`
  * overloads do not drive `wm`/`wb`/`rb`; use the overloads defined here.
  *
  * @param w_ram data width in bits; must be a multiple of 8
  * @param w_idx address width in bits
  */
class SpRamCtrlWithDirty(w_ram:Int,w_idx:Int) extends SpRamCtrl(w_ram,w_idx){
    require(w_ram % 8 == 0, s"SpRamCtrlWithDirty: w_ram ($w_ram) must be a multiple of 8")

    val nbyte = w_ram/8
    val wm = Bits(nbyte bits)
    val wb = Bool()
    val rb = Bool()

    override def asMaster():Unit = {
        super.asMaster()
        out(wb,wm)
        in(rb)
    }

    /** Wires this interface to a byte-write block-RAM black box.
      *
      * The RAM must have `nbyte` lanes of 9 bits each: data byte in the low
      * 8 bits, dirty bit on top.
      *
      * @throws IllegalArgumentException if the RAM geometry does not match
      */
    def connect(ram:SpByteBRam):Unit = {
        require(ram.w == nbyte && ram.bitPerByte == 9,
            s"SpRamCtrlWithDirty.connect: expected $nbyte lanes of 9 bits, got ${ram.w} lanes of ${ram.bitPerByte} bits")
        // Iterate lanes from most to least significant: `##` puts its left
        // operand in the MSBs, so lane i ends up at bits [i*9, i*9+9).
        val lanes = (nbyte - 1) to 0 by -1
        ram.ena   := en
        ram.wea   := Mux(we, wm, B(0, nbyte bits))
        ram.addra := a
        ram.dina  := lanes.map(i => wb ## wd(i*8, 8 bits).asBits).reduce(_ ## _)
        rd := lanes.map(i => ram.douta(i*9, 8 bits)).reduce(_ ## _).asUInt
        rb := lanes.map(i => ram.douta(i*9+8)).reduce(_ || _)
    }

    override def getMem():Mem[UInt] = {
        throw new RuntimeException("Not Implemented getMem, use getArrMem Instead")
    }

    /** Creates one 9-bit-wide memory per byte lane (dirty bit in the MSB). */
    def getArrMem():IndexedSeq[Mem[UInt]] = {
        IndexedSeq.fill(nbyte)(Mem(UInt(9 bits), 1 << w_idx))
    }

    /** Wires this interface to per-lane memories as produced by `getArrMem`.
      *
      * Each lane uses its own synchronous read/write port and only writes when
      * its `wm` bit is set.
      */
    def connect(mem:IndexedSeq[Mem[UInt]]):Unit = {
        val dats = (0 until nbyte).map { i =>
            mem(i).readWriteSync(a, wb.asUInt @@ wd(i*8, 8 bits), en, we && wm(i))
        }
        // Drive every lane of rd (previously hard-coded to 4 lanes).
        for(i <- 0 until nbyte){
            rd(i*8, 8 bits) := dats(i)(7 downto 0)
        }
        rb := dats.map(_.msb).reduce(_||_)
    }
}

/** 256 x 32-bit single-port RAM.
  *
  * Backed by the `inst_bram` black box when `RamConfig.use_vivado` is set;
  * otherwise backed by a SpinalHDL `Mem`. Only one of `ram`/`mem` is non-null.
  */
class InstBlockRam extends Component{
    val io  = slave(new SpRamCtrl(32, 8))
    val ram = if (RamConfig.use_vivado) new inst_bram else null
    val mem = if (RamConfig.use_vivado) null else io.getMem()
    if (RamConfig.use_vivado) {
        io.connect(ram)
    } else {
        io.connect(mem)
    }
}
/** 256 x 21-bit single-port RAM.
  *
  * Backed by the `tagv_bram` black box when `RamConfig.use_vivado` is set;
  * otherwise backed by a SpinalHDL `Mem`. Only one of `ram`/`mem` is non-null.
  */
class TagVBlockRam extends Component{
    val io  = slave(new SpRamCtrl(21, 8))
    val ram = if (RamConfig.use_vivado) new tagv_bram else null
    val mem = if (RamConfig.use_vivado) null else io.getMem()
    if (RamConfig.use_vivado) {
        io.connect(ram)
    } else {
        io.connect(mem)
    }
}
/** 512 x 32-bit single-port RAM with byte write mask and per-byte dirty bits.
  *
  * Backed by the `data_bram` black box when `RamConfig.use_vivado` is set;
  * otherwise backed by one SpinalHDL `Mem` per byte lane. Only one of
  * `ram`/`mem` is non-null.
  */
class DataBlockRam extends Component{
    val io  = slave(new SpRamCtrlWithDirty(32, 9))
    val ram = if (RamConfig.use_vivado) new data_bram else null
    val mem = if (RamConfig.use_vivado) null else io.getArrMem()
    if (RamConfig.use_vivado) {
        io.connect(ram)
    } else {
        io.connect(mem)
    }
}
